# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention-based sequence-to-sequence model with dynamic RNN support."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

import model
import model_helper
import attention_utils

from lottery import lottery

__all__ = ["AttentionModel"]


class AttentionModel(model.Model):
  """Dynamic sequence-to-sequence model whose decoder uses attention.

  The encoder is a multi-layer recurrent network and the decoder attends over
  the encoder outputs, following (Luong et al., EMNLP'2015):
  https://arxiv.org/pdf/1508.04025v5.pdf.
  GRU cells can be used in addition to LSTM cells, with support for dropout.
  """

  def __init__(self,
               hparams,
               mode,
               features,
               scope=None,
               extra_args=None):
    """Records the attention configuration, then runs the base constructor.

    Args:
      hparams: Hyperparameters. Attention is enabled only when both
        `hparams.attention_architecture` and `hparams.attention` are truthy.
      mode: Forwarded unchanged to the base constructor.
      features: Forwarded unchanged to the base constructor.
      scope: Forwarded unchanged to the base constructor.
      extra_args: Optional object. When it is truthy and its
        `attention_mechanism_fn` attribute is truthy, that callable is used
        in place of the module-level `create_attention_mechanism`.
    """
    self.has_attention = hparams.attention_architecture and hparams.attention

    # These attributes are assigned before the base constructor runs.
    # NOTE(review): presumably the base constructor ends up calling
    # _build_decoder_cell, which reads them -- confirm against model.Model.
    if self.has_attention:
      override_fn = extra_args and extra_args.attention_mechanism_fn
      self.attention_mechanism_fn = override_fn or create_attention_mechanism

    super(AttentionModel, self).__init__(
        hparams=hparams,
        mode=mode,
        features=features,
        scope=scope,
        extra_args=extra_args)

  def _prepare_beam_search_decoder_inputs(
      self, beam_width, memory, source_sequence_length, encoder_state):
    """Tiles the decoder inputs `beam_width` times for beam search.

    Args:
      beam_width: Multiplier handed to `tile_batch` and applied to the batch
        size.
      memory: Attention memory to tile.
      source_sequence_length: Source lengths to tile.
      encoder_state: Encoder state to tile.

    Returns:
      A tuple `(memory, source_sequence_length, encoder_state, batch_size)`
      holding the tiled inputs and `self.batch_size * beam_width`.
    """
    tile = tf.contrib.seq2seq.tile_batch
    tiled_memory = tile(memory, multiplier=beam_width)
    tiled_lengths = tile(source_sequence_length, multiplier=beam_width)
    tiled_state = tile(encoder_state, multiplier=beam_width)
    return (tiled_memory, tiled_lengths, tiled_state,
            self.batch_size * beam_width)

  def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                          source_sequence_length):
    """Builds the attention-wrapped decoder cell and its initial state.

    Args:
      hparams: Hyperparameters describing the cell and attention setup.
      encoder_outputs: Encoder outputs used as attention memory. They are
        transposed with permutation [1, 0, 2] when `self.time_major` is set.
      encoder_state: Encoder final state; used as the wrapped cell state of
        the initial decoder state when `hparams.pass_hidden_state` is truthy.
      source_sequence_length: Source lengths passed to the attention
        mechanism.

    Returns:
      A tuple `(cell, decoder_initial_state)`. Without attention this is
      whatever the base class `_build_decoder_cell` returns.

    Raises:
      ValueError: If attention is enabled and
        `hparams.attention_architecture` is not "standard".
    """
    # Without attention the base class builds the decoder cell.
    if not self.has_attention:
      return super(AttentionModel, self)._build_decoder_cell(
          hparams, encoder_outputs, encoder_state, source_sequence_length)

    if hparams.attention_architecture != "standard":
      raise ValueError(
          "Unknown attention architecture %s" % hparams.attention_architecture)

    num_units = hparams.num_units
    num_layers = self.num_decoder_layers
    num_residual_layers = self.num_decoder_residual_layers
    infer_mode = hparams.infer_mode
    dtype = self.dtype

    is_infer = self.mode == tf.contrib.learn.ModeKeys.INFER
    use_beam_search = is_infer and infer_mode == "beam_search"

    # The attention memory has to be batch-major.
    memory = (tf.transpose(encoder_outputs, [1, 0, 2])
              if self.time_major else encoder_outputs)

    if use_beam_search:
      (memory, source_sequence_length, encoder_state,
       batch_size) = self._prepare_beam_search_decoder_inputs(
           hparams.beam_width, memory, source_sequence_length, encoder_state)
    else:
      batch_size = self.batch_size

    # Attention mechanism over the (possibly tiled) memory.
    attention_mechanism = self.attention_mechanism_fn(
        hparams.attention, num_units, memory, source_sequence_length, self.mode)

    rnn_cell = model_helper.create_rnn_cell(
        unit_type=hparams.unit_type,
        num_units=num_units,
        num_layers=num_layers,
        num_residual_layers=num_residual_layers,
        forget_bias=hparams.forget_bias,
        dropout=hparams.dropout,
        mode=self.mode,
        single_cell_fn=self.single_cell_fn,
        global_step=self.global_step)

    # Alignment history is only recorded for INFER mode without beam search.
    keep_alignments = is_infer and infer_mode != "beam_search"
    attention_cell = tf.contrib.seq2seq.AttentionWrapper(
        rnn_cell,
        attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=keep_alignments,
        output_attention=hparams.output_attention,
        name="attention")

    # Start from the zero state; optionally seed its cell state from the
    # encoder.
    initial_state = attention_cell.zero_state(batch_size, dtype)
    if hparams.pass_hidden_state:
      initial_state = initial_state.clone(cell_state=encoder_state)

    return attention_cell, initial_state


def create_attention_mechanism(attention_option, num_units, memory,
                               source_sequence_length, mode):
  """Returns the attention mechanism selected by `attention_option`.

  Args:
    attention_option: One of "luong", "scaled_luong", "bahdanau" or
      "normed_bahdanau".
    num_units: Passed to the mechanism constructor as its first argument.
    memory: Attention memory; its dtype also determines the dtype of the
      score mask value.
    source_sequence_length: Passed to the mechanism as
      `memory_sequence_length`.
    mode: Unused.

  Returns:
    The constructed attention mechanism.

  Raises:
    ValueError: If `attention_option` is not one of the supported names.
  """
  del mode  # unused

  # -inf in the memory's dtype, handed to every mechanism as its mask value.
  numpy_dtype = tf.as_dtype(memory.dtype).as_numpy_dtype
  score_mask_value = tf.convert_to_tensor(numpy_dtype(-np.inf))

  # Keyword arguments shared by all supported mechanisms.
  shared_kwargs = {
      "memory_sequence_length": source_sequence_length,
      "score_mask_value": score_mask_value,
  }

  if attention_option == "luong":
    return tf.contrib.seq2seq.LuongAttention(
        num_units, memory, **shared_kwargs)

  if attention_option == "scaled_luong":
    return tf.contrib.seq2seq.LuongAttention(
        num_units, memory, scale=True, **shared_kwargs)

  if attention_option == "bahdanau":
    return tf.contrib.seq2seq.BahdanauAttention(
        num_units, memory, **shared_kwargs)

  if attention_option == "normed_bahdanau":
    # The normalized variant comes from the project's attention_utils module
    # and additionally receives the memory dtype.
    return attention_utils.BahdanauAttention(
        num_units,
        memory,
        normalize=True,
        dtype=memory.dtype,
        **shared_kwargs)

  raise ValueError("Unknown attention option %s" % attention_option)
